{ "cells": [ { "cell_type": "markdown", "source": [ "# Simple Pipeline\n", "\n", "In this notebook, we show a full example of how the `fairret` library might be used to train a PyTorch model with a fairness cost." ], "metadata": { "collapsed": false }, "id": "46617c2b3a734bfb" }, { "cell_type": "markdown", "source": [ "## Loading some data\n", "To start, let's load some data where fair binary classification is desirable. We'll use the `folktables` [library](https://github.com/socialfoundations/folktables) and their example data of the 2018 [American Community Survey](https://www.census.gov/programs-surveys/acs) (ACS)." ], "metadata": { "collapsed": false }, "id": "bb38ec7e99e970b8" }, { "cell_type": "code", "execution_count": 1, "outputs": [], "source": [ "from folktables import ACSDataSource\n", "\n", "data_source = ACSDataSource(survey_year='2018', horizon='1-Year', survey='person')\n", "data = data_source.get_data(states=[\"AL\"], download=True)" ], "metadata": { "collapsed": false, "ExecuteTime": { "end_time": "2024-07-23T15:10:07.745118500Z", "start_time": "2024-07-23T15:10:06.635591900Z" } }, "id": "f9d1ac35f79de8f7" }, { "cell_type": "raw", "source": [ "We specifically address the ACSIncome task, where we predict whether an individual's income is above $50,000." ], "metadata": { "collapsed": false }, "id": "95e7af7af46a593e" }, { "cell_type": "code", "execution_count": 2, "outputs": [ { "data": { "text/plain": " AGEP WKHP \\\n0 18 21.0 \n1 53 40.0 \n2 41 40.0 \n3 18 2.0 \n4 21 50.0 \n\n COW_Employee of a private for-profit company or business, or of an individual, for wages, salary, or commissions \\\n0 True \n1 False \n2 True \n3 False \n4 False \n\n COW_Employee of a private not-for-profit, tax-exempt, or charitable organization \\\n0 False \n1 False \n2 False \n3 False \n4 False \n\n COW_Federal government employee \\\n0 False \n1 True \n2 False \n3 False \n4 True \n\n COW_Local government employee (city, county, etc.) 
\\\n0 False \n1 False \n2 False \n3 False \n4 False \n\n COW_Self-employed in own incorporated business, professional practice or farm \\\n0 False \n1 False \n2 False \n3 False \n4 False \n\n COW_Self-employed in own not incorporated business, professional practice, or farm \\\n0 False \n1 False \n2 False \n3 True \n4 False \n\n COW_State government employee \\\n0 False \n1 False \n2 False \n3 False \n4 False \n\n COW_Working without pay in family business or farm ... SEX_Male \\\n0 False ... False \n1 False ... True \n2 False ... True \n3 False ... False \n4 False ... True \n\n RAC1P_Alaska Native alone RAC1P_American Indian alone \\\n0 False False \n1 False False \n2 False False \n3 False False \n4 False False \n\n RAC1P_American Indian and Alaska Native tribes specified; or American Indian or Alaska Native, not specified and no other races \\\n0 False \n1 False \n2 False \n3 False \n4 False \n\n RAC1P_Asian alone RAC1P_Black or African American alone \\\n0 False True \n1 False False \n2 False False \n3 False False \n4 False False \n\n RAC1P_Native Hawaiian and Other Pacific Islander alone \\\n0 False \n1 False \n2 False \n3 False \n4 False \n\n RAC1P_Some Other Race alone RAC1P_Two or More Races RAC1P_White alone \n0 False False False \n1 False False True \n2 False False True \n3 False False True \n4 False False True \n\n[5 rows x 729 columns]", "text/html": "
5 rows × 729 columns
" }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from folktables import ACSIncome, generate_categories\n", "\n", "definition_df = data_source.get_definitions(download=True)\n", "categories = generate_categories(features=ACSIncome.features, definition_df=definition_df)\n", "\n", "df_feat, df_labels, _ = ACSIncome.df_to_pandas(data, categories=categories, dummies=True)\n", "df_feat.head()" ], "metadata": { "collapsed": false, "ExecuteTime": { "end_time": "2024-07-23T15:10:08.662780500Z", "start_time": "2024-07-23T15:10:07.745118500Z" } }, "id": "a58a200cf2c54dc5" }, { "cell_type": "markdown", "source": [ "To keep things simple for now, let's only consider two sensitive groups: *male* and *female*." ], "metadata": { "collapsed": false }, "id": "6b66945992b5370" }, { "cell_type": "code", "execution_count": 3, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[0.47808514 0.52191486]\n" ] } ], "source": [ "sens_cols = ['SEX_Female', 'SEX_Male']\n", "feat = df_feat.drop(columns=sens_cols).to_numpy(dtype=\"float\")\n", "sens = df_feat[sens_cols].to_numpy(dtype=\"float\")\n", "label = df_labels.to_numpy(dtype=\"float\")\n", "\n", "print(sens.mean(axis=0))" ], "metadata": { "collapsed": false, "ExecuteTime": { "end_time": "2024-07-23T15:10:09.749483900Z", "start_time": "2024-07-23T15:10:08.666619700Z" } }, "id": "a58012f8494eb69c" }, { "cell_type": "markdown", "source": [ "## A naive PyTorch pipeline\n", "\n", "The `fairret` library treats sensitive features in the same way 'normal' features are treated in PyTorch: as (N x D) tensors, where N is the number of samples and D the dimensionality. In contrast to other fairness libraries you may have used, we can therefore just leave categorical sensitive features as one-hot encoded!" 
], "metadata": { "collapsed": false }, "id": "2e4fc2e65d3eb308" }, { "cell_type": "code", "execution_count": 4, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Shape of the 'normal' features tensor: torch.Size([22268, 727])\n", "Shape of the sensitive features tensor: torch.Size([22268, 2])\n", "Shape of the labels tensor: torch.Size([22268, 1])\n" ] } ], "source": [ "import torch\n", "torch.manual_seed(0)\n", "feat, sens, label = torch.tensor(feat).float(), torch.tensor(sens).float(), torch.tensor(label).float()\n", "print(f\"Shape of the 'normal' features tensor: {feat.shape}\")\n", "print(f\"Shape of the sensitive features tensor: {sens.shape}\")\n", "print(f\"Shape of the labels tensor: {label.shape}\")" ], "metadata": { "collapsed": false, "ExecuteTime": { "end_time": "2024-07-23T15:10:12.628364200Z", "start_time": "2024-07-23T15:10:09.739634Z" } }, "id": "708d7de5a72046d2" }, { "cell_type": "markdown", "source": [ "In typical PyTorch fashion, let's now define a simple neural net with 1 hidden layer, an optimizer, and a DataLoader." ], "metadata": { "collapsed": false }, "id": "c2b2b36e3b6b1c4d" }, { "cell_type": "code", "execution_count": 5, "outputs": [], "source": [ "h_layer_dim = 16\n", "lr = 1e-3\n", "batch_size = 1024\n", "\n", "model = torch.nn.Sequential(\n", " torch.nn.Linear(feat.shape[1], h_layer_dim),\n", " torch.nn.ReLU(),\n", " torch.nn.Linear(h_layer_dim, 1)\n", ")\n", "optimizer = torch.optim.Adam(model.parameters(), lr=lr)\n", "\n", "from torch.utils.data import TensorDataset, DataLoader\n", "dataset = TensorDataset(feat, sens, label)\n", "dataloader = DataLoader(dataset, batch_size=batch_size)" ], "metadata": { "collapsed": false, "ExecuteTime": { "end_time": "2024-07-23T15:10:13.337370400Z", "start_time": "2024-07-23T15:10:12.623100Z" } }, "id": "3b1d4337167c413e" }, { "cell_type": "markdown", "source": [ "Now, let's train it without doing any fairness adjustment..." 
], "metadata": { "collapsed": false }, "id": "4a9e4d246d0f16af" }, { "cell_type": "code", "execution_count": 6, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch: 0, loss: 0.6495096764781259\n", "Epoch: 1, loss: 0.631090676242655\n", "Epoch: 2, loss: 0.6120786558498036\n", "Epoch: 3, loss: 0.5901886902072213\n", "Epoch: 4, loss: 0.5662552903998982\n", "Epoch: 5, loss: 0.5412234474312175\n", "Epoch: 6, loss: 0.5169245397502725\n", "Epoch: 7, loss: 0.4951953955672004\n", "Epoch: 8, loss: 0.4771566512909802\n", "Epoch: 9, loss: 0.4624161679636348\n", "Epoch: 10, loss: 0.45037723264910956\n", "Epoch: 11, loss: 0.4405312815850431\n", "Epoch: 12, loss: 0.43243050778454\n", "Epoch: 13, loss: 0.42573171854019165\n", "Epoch: 14, loss: 0.4201531301845204\n", "Epoch: 15, loss: 0.4154460768808018\n", "Epoch: 16, loss: 0.41148471154949884\n", "Epoch: 17, loss: 0.40810865570198407\n", "Epoch: 18, loss: 0.4051740806211125\n", "Epoch: 19, loss: 0.4026364270936359\n", "Epoch: 20, loss: 0.400410919026895\n", "Epoch: 21, loss: 0.39844207533381204\n", "Epoch: 22, loss: 0.3967087905515324\n", "Epoch: 23, loss: 0.3951598744500767\n", "Epoch: 24, loss: 0.3937697451223027\n", "Epoch: 25, loss: 0.39251258156516333\n", "Epoch: 26, loss: 0.3913724842396649\n", "Epoch: 27, loss: 0.3903204453262416\n", "Epoch: 28, loss: 0.3893602354960008\n", "Epoch: 29, loss: 0.38850402154705743\n", "Epoch: 30, loss: 0.3876958462325009\n", "Epoch: 31, loss: 0.38694334572011774\n", "Epoch: 32, loss: 0.3862581164999442\n", "Epoch: 33, loss: 0.38562133434143936\n", "Epoch: 34, loss: 0.3850080926309932\n", "Epoch: 35, loss: 0.38443617725914175\n", "Epoch: 36, loss: 0.38391522047194565\n", "Epoch: 37, loss: 0.38343126529997046\n", "Epoch: 38, loss: 0.3829596929929473\n", "Epoch: 39, loss: 0.3825045864690434\n", "Epoch: 40, loss: 0.3820722380822355\n", "Epoch: 41, loss: 0.38168060102246026\n", "Epoch: 42, loss: 0.38129559497941623\n", "Epoch: 43, loss: 0.3809025511145592\n", "Epoch: 44, 
loss: 0.3805498616261916\n", "Epoch: 45, loss: 0.38022591309113934\n", "Epoch: 46, loss: 0.3798875537785617\n", "Epoch: 47, loss: 0.379513679580255\n", "Epoch: 48, loss: 0.3791568949818611\n", "Epoch: 49, loss: 0.37884155728600244\n" ] } ], "source": [ "import numpy as np\n", "\n", "nb_epochs = 50\n", "\n", "for epoch in range(nb_epochs):\n", "    losses = []\n", "    for batch_feat, batch_sens, batch_label in dataloader:\n", "        optimizer.zero_grad()\n", "\n", "        logit = model(batch_feat)\n", "        loss = torch.nn.functional.binary_cross_entropy_with_logits(logit, batch_label)\n", "        loss.backward()\n", "\n", "        optimizer.step()\n", "        losses.append(loss.item())\n", "    print(f\"Epoch: {epoch}, loss: {np.mean(losses)}\")" ], "metadata": { "collapsed": false, "ExecuteTime": { "end_time": "2024-07-23T15:10:37.258908300Z", "start_time": "2024-07-23T15:10:13.337370400Z" } }, "id": "cd4ba74cb852b8d4" }, { "cell_type": "markdown", "source": [ "## Bias analysis in fairret\n", "\n", "Can we detect any statistical disparities (biases) in the naive model?\n", "\n", "The `fairret` library assesses these biases by comparing a (linear-fractional) `Statistic` computed for each sensitive feature. In our example, these are the 'SEX_Female' and 'SEX_Male' features. As an example, let's look at the true positive rate (also known as recall or sensitivity)." 
], "metadata": { "collapsed": false }, "id": "9ee41dc8676fa3b8" }, { "cell_type": "code", "execution_count": 7, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "The TruePositiveRate for group SEX_Female is 0.5624983310699463\n", "The TruePositiveRate for group SEX_Male is 0.6300471425056458\n", "The absolute difference is 0.06754881143569946\n" ] } ], "source": [ "from fairret.statistic import TruePositiveRate\n", "\n", "statistic = TruePositiveRate()\n", "\n", "pred = torch.sigmoid(model(feat))\n", "stat_per_group = statistic(pred, sens, label)\n", "absolute_diff = torch.abs(stat_per_group[0] - stat_per_group[1])\n", "\n", "print(f\"The {statistic.__class__.__name__} for group {sens_cols[0]} is {stat_per_group[0]}\")\n", "print(f\"The {statistic.__class__.__name__} for group {sens_cols[1]} is {stat_per_group[1]}\")\n", "print(f\"The absolute difference is {absolute_diff}\")" ], "metadata": { "collapsed": false, "ExecuteTime": { "end_time": "2024-07-23T15:10:37.283551800Z", "start_time": "2024-07-23T15:10:37.259404700Z" } }, "id": "d9c2d8f9f763a563" }, { "cell_type": "markdown", "source": [ "## Bias mitigation in fairret\n", "\n", "To reduce the statistical disparity we found, we can use one of the fairrets implemented in the library. So that the loss measures bias with respect to the intended statistic, we pass the statistic object to the fairret loss." ], "metadata": { "collapsed": false }, "id": "4e91b13a911a6963" }, { "cell_type": "code", "execution_count": 8, "outputs": [], "source": [ "from fairret.loss import NormLoss\n", "\n", "norm_loss = NormLoss(statistic)" ], "metadata": { "collapsed": false, "ExecuteTime": { "end_time": "2024-07-23T15:10:37.485369200Z", "start_time": "2024-07-23T15:10:37.286512500Z" } }, "id": "c794c8c6dac3bea9" }, { "cell_type": "markdown", "source": [ "Let's train another model where we now add this loss term to the objective. 
\n", "\n", "**We only need to add one line of code to the standard PyTorch training loop!**" ], "metadata": { "collapsed": false }, "id": "816cf465d52b6ed1" }, { "cell_type": "code", "execution_count": 9, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch: 0, loss: 0.6555492119355635\n", "Epoch: 1, loss: 0.6289367729967291\n", "Epoch: 2, loss: 0.6124052486636422\n", "Epoch: 3, loss: 0.5947602743452246\n", "Epoch: 4, loss: 0.5761971880089153\n", "Epoch: 5, loss: 0.5571780936284498\n", "Epoch: 6, loss: 0.538903527639129\n", "Epoch: 7, loss: 0.5221295695413243\n", "Epoch: 8, loss: 0.5069978806105527\n", "Epoch: 9, loss: 0.4935957017270001\n", "Epoch: 10, loss: 0.48187787559899414\n", "Epoch: 11, loss: 0.4717836068435149\n", "Epoch: 12, loss: 0.46291052482344885\n", "Epoch: 13, loss: 0.4552074447274208\n", "Epoch: 14, loss: 0.44845194775949826\n", "Epoch: 15, loss: 0.4425306184725328\n", "Epoch: 16, loss: 0.43727567114613275\n", "Epoch: 17, loss: 0.4326496706767516\n", "Epoch: 18, loss: 0.4285395748235963\n", "Epoch: 19, loss: 0.42487049102783203\n", "Epoch: 20, loss: 0.42168024250052194\n", "Epoch: 21, loss: 0.418713163245808\n", "Epoch: 22, loss: 0.4161644611846317\n", "Epoch: 23, loss: 0.41383336958560074\n", "Epoch: 24, loss: 0.4117328572002324\n", "Epoch: 25, loss: 0.4098751870068637\n", "Epoch: 26, loss: 0.4081542722203515\n", "Epoch: 27, loss: 0.4065771651538936\n", "Epoch: 28, loss: 0.40517762438817456\n", "Epoch: 29, loss: 0.40386907011270523\n", "Epoch: 30, loss: 0.40267594226382\n", "Epoch: 31, loss: 0.4015895886854692\n", "Epoch: 32, loss: 0.40053029832514847\n", "Epoch: 33, loss: 0.3996298699216409\n", "Epoch: 34, loss: 0.398726802657951\n", "Epoch: 35, loss: 0.39790612052787433\n", "Epoch: 36, loss: 0.3971890244971622\n", "Epoch: 37, loss: 0.3964132788506421\n", "Epoch: 38, loss: 0.3957923088561405\n", "Epoch: 39, loss: 0.3951647024263035\n", "Epoch: 40, loss: 0.39456393027847464\n", "Epoch: 41, loss: 0.3940161994912408\n", "Epoch: 
42, loss: 0.39341962202028796\n", "Epoch: 43, loss: 0.39296053688634525\n", "Epoch: 44, loss: 0.3924828178503297\n", "Epoch: 45, loss: 0.3919635293158618\n", "Epoch: 46, loss: 0.3915071880275553\n", "Epoch: 47, loss: 0.39117465845563193\n", "Epoch: 48, loss: 0.390705829994245\n", "Epoch: 49, loss: 0.3903096318244934\n" ] } ], "source": [ "fairness_strength = 0.1\n", "model = torch.nn.Sequential(\n", "    torch.nn.Linear(feat.shape[1], h_layer_dim),\n", "    torch.nn.ReLU(),\n", "    torch.nn.Linear(h_layer_dim, 1)\n", ")\n", "optimizer = torch.optim.Adam(model.parameters(), lr=lr)\n", "\n", "for epoch in range(nb_epochs):\n", "    losses = []\n", "    for batch_feat, batch_sens, batch_label in dataloader:\n", "        optimizer.zero_grad()\n", "\n", "        logit = model(batch_feat)\n", "        loss = torch.nn.functional.binary_cross_entropy_with_logits(logit, batch_label)\n", "        # The only addition: the fairret term, weighted by fairness_strength\n", "        loss += fairness_strength * norm_loss(logit, batch_sens, batch_label)\n", "        loss.backward()\n", "\n", "        optimizer.step()\n", "        losses.append(loss.item())\n", "    print(f\"Epoch: {epoch}, loss: {np.mean(losses)}\")" ], "metadata": { "collapsed": false, "ExecuteTime": { "end_time": "2024-07-23T15:11:03.003056Z", "start_time": "2024-07-23T15:10:37.489355500Z" } }, "id": "e8e1f54235f8b623" }, { "cell_type": "markdown", "source": [ "Let's check the true positive rate per group again..." 
], "metadata": { "collapsed": false }, "id": "8c405074b53ccaea" }, { "cell_type": "code", "execution_count": 10, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "The TruePositiveRate for group SEX_Female is 0.5829195976257324\n", "The TruePositiveRate for group SEX_Male is 0.6023395657539368\n", "The absolute difference is 0.019419968128204346\n" ] } ], "source": [ "pred = torch.sigmoid(model(feat))\n", "stat_per_group = statistic(pred, sens, label)\n", "\n", "print(f\"The {statistic.__class__.__name__} for group {sens_cols[0]} is {stat_per_group[0]}\")\n", "print(f\"The {statistic.__class__.__name__} for group {sens_cols[1]} is {stat_per_group[1]}\")\n", "print(f\"The absolute difference is {torch.abs(stat_per_group[0] - stat_per_group[1])}\")" ], "metadata": { "collapsed": false, "ExecuteTime": { "end_time": "2024-07-23T15:11:03.038099300Z", "start_time": "2024-07-23T15:11:03.005789100Z" } }, "id": "657c2799f5965e8a" }, { "cell_type": "markdown", "source": [ "With a small change, the absolute difference in true positive rate between the two groups was reduced from about 6.8 to 1.9 percentage points!\n", "\n", "Though this was a simple example, it illustrates how powerful the `fairret` paradigm can be.\n", "\n", "Feel free to go back and try out some other statistics to compare or fairret losses to minimize. Both are designed to be easily interchangeable and extensible." ], "metadata": { "collapsed": false }, "id": "2ac5d9c371b69b49" } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 2 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython2", "version": "2.7.6" } }, "nbformat": 4, "nbformat_minor": 5 }